{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Efficient Marginals Computation\n",
    "\n",
    "GTSAM can calculate marginals in Bayes trees very efficiently. In this post, we illustrate the “shortcut” mechanism, which **caches** the conditional distribution $P(S \\mid R)$ of a clique's separator $S$ given the root clique $R$ so that later marginal queries can reuse it. We assume familiarity with **Bayes trees** from [the previous post](#)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Toy Example\n",
    "\n",
    "We create a small Bayes tree:\n",
    "\n",
    "\\begin{equation}\n",
    "P(a \\mid b) P(b,c \\mid r) P(f \\mid e) P(d,e \\mid r) P(r).\n",
    "\\end{equation}\n",
    "\n",
    "The Python code below uses GTSAM’s discrete wrappers to define and build the corresponding Bayes tree. Since the example is discrete, we create a `DiscreteBayesTree`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gtsam import DiscreteConditional, DiscreteBayesTree, DiscreteBayesTreeClique, DecisionTreeFactor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make discrete keys (key in elimination order, cardinality):\n",
    "keys = [(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 2), (6, 2)]\n",
    "names = {0: 'a', 1: 'f', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'r'}\n",
    "aKey, fKey, bKey, cKey, dKey, eKey, rKey = keys\n",
    "keyFormatter = lambda key: names[key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Root Clique: P(r)\n",
    "cliqueR = DiscreteBayesTreeClique(DiscreteConditional(rKey, \"0.4/0.6\"))\n",
    "\n",
    "# 2. Child Clique 1: P(b, c | r)\n",
    "# The leading 2 is the number of frontal variables (here b and c); r is the parent.\n",
    "cliqueBC = DiscreteBayesTreeClique(\n",
    "    DiscreteConditional(\n",
    "        2, DecisionTreeFactor([bKey, cKey, rKey], \"0.3 0.7 0.1 0.9 0.2 0.8 0.4 0.6\")\n",
    "    )\n",
    ")\n",
    "\n",
    "# 3. Child Clique 2: P(d, e | r)\n",
    "cliqueDE = DiscreteBayesTreeClique(\n",
    "    DiscreteConditional(\n",
    "        2, DecisionTreeFactor([dKey, eKey, rKey], \"0.1 0.9 0.9 0.1 0.2 0.8 0.3 0.7\")\n",
    "    )\n",
    ")\n",
    "\n",
    "# 4. Leaf Clique from Child 1: P(a | b)\n",
    "cliqueA = DiscreteBayesTreeClique(DiscreteConditional(aKey, [bKey], \"1/3 3/1\"))\n",
    "\n",
    "# 5. Leaf Clique from Child 2: P(f | e)\n",
    "cliqueF = DiscreteBayesTreeClique(DiscreteConditional(fKey, [eKey], \"1/3 3/1\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the BayesTree:\n",
    "bayesTree = DiscreteBayesTree()\n",
    "\n",
    "# Insert root:\n",
    "bayesTree.insertRoot(cliqueR)\n",
    "\n",
    "# Attach child cliques to root:\n",
    "bayesTree.addClique(cliqueBC, cliqueR)\n",
    "bayesTree.addClique(cliqueDE, cliqueR)\n",
    "\n",
    "# Attach leaf cliques:\n",
    "bayesTree.addClique(cliqueA, cliqueBC)\n",
    "bayesTree.addClique(cliqueF, cliqueDE)\n",
    "\n",
    "# bayesTree.print(\"bayesTree\", keyFormatter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 12.0.0 (0)\n",
       " -->\n",
       "<!-- Title: G Pages: 1 -->\n",
       "<svg width=\"168pt\" height=\"188pt\"\n",
       " viewBox=\"0.00 0.00 167.97 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
       "<title>G</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-184 163.97,-184 163.97,4 -4,4\"/>\n",
       "<!-- 0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>0</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"79.49\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"79.49\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">r</text>\n",
       "</g>\n",
       "<!-- 1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>1</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"35.49\" cy=\"-90\" rx=\"35.49\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"35.49\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">b, c : r</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>0&#45;&gt;1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M69.5,-145.12C64.29,-136.82 57.78,-126.46 51.85,-117.03\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"54.83,-115.19 46.54,-108.59 48.9,-118.92 54.83,-115.19\"/>\n",
       "</g>\n",
       "<!-- 3 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>3</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"124.49\" cy=\"-90\" rx=\"35.49\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"124.49\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">d, e : r</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;3 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>0&#45;&gt;3</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M89.7,-145.12C95.09,-136.73 101.83,-126.24 107.94,-116.73\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"110.72,-118.88 113.18,-108.58 104.83,-115.1 110.72,-118.88\"/>\n",
       "</g>\n",
       "<!-- 2 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"35.49\" cy=\"-18\" rx=\"27.3\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"35.49\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">a : b</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;2 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>1&#45;&gt;2</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M35.49,-71.7C35.49,-64.41 35.49,-55.73 35.49,-47.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"38.99,-47.62 35.49,-37.62 31.99,-47.62 38.99,-47.62\"/>\n",
       "</g>\n",
       "<!-- 4 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>4</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"124.49\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"124.49\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">f : e</text>\n",
       "</g>\n",
       "<!-- 3&#45;&gt;4 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>3&#45;&gt;4</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M124.49,-71.7C124.49,-64.41 124.49,-55.73 124.49,-47.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"127.99,-47.62 124.49,-37.62 120.99,-47.62 127.99,-47.62\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.sources.Source at 0x10796f1a0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import graphviz\n",
    "graphviz.Source(bayesTree.dot(keyFormatter))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Naive Computation of P(a)\n",
    "The marginal $P(a)$ can be computed by summing out the other variables in the tree:\n",
    "$$\n",
    "P(a) = \\sum_{b, c, d, e, f, r} P(a, b, c, d, e, f, r)\n",
    "$$\n",
    "\n",
    "Using the Bayes tree structure, we have\n",
    "\n",
    "$$\n",
    "P(a) = \\sum_{b, c, d, e, f, r} P(a \\mid b) P(b, c \\mid r) P(f \\mid e) P(d, e \\mid r) P(r)  \n",
    "$$\n",
    "\n",
    "but we can ignore the variables $d$, $e$, and $f$, which are not on the path from $a$ to the root $r$. Indeed, by regrouping the sums we have\n",
    "\n",
    "$$\n",
    "P(a) = \\sum_{r} \\Bigl\\{ \\sum_{d, e, f} P(f \\mid e) P(d, e \\mid r) \\Bigr\\} \\sum_{b, c} P(a \\mid b) P(b, c \\mid r) P(r)\n",
    "$$\n",
    "\n",
    "where the grouped terms sum to one for any value of $r$, and hence\n",
    "\n",
    "$$\n",
    "P(a) = \\sum_{r, b, c} P(a \\mid b) P(b, c \\mid r) P(r).\n",
    "$$"
   ]
  },
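  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the other branch drops out, note that summing a conditional over all of its frontal variables gives one. A short derivation for the branch containing $d$, $e$, and $f$, using only the factors defined above:\n",
    "\n",
    "$$\n",
    "\\sum_{d, e, f} P(f \\mid e) P(d, e \\mid r) = \\sum_{d, e} P(d, e \\mid r) \\sum_{f} P(f \\mid e) = \\sum_{d, e} P(d, e \\mid r) = 1.\n",
    "$$\n",
    "\n",
    "The same argument applies to any subtree that does not contain the queried variable, which is why only the path from the query clique to the root matters."
   ]
  },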
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Memoization via Shortcuts\n",
    "\n",
    "In GTSAM, we compute this marginal recursively, in up to three steps.\n",
    "\n",
    "#### Step 1\n",
    "We want to compute the marginal via\n",
    "$$\n",
    "P(a) = \\sum_{b} P(a \\mid b) P(b),\n",
    "$$\n",
    "where $P(b)$ is the marginal of this clique's separator $b$.\n",
    "\n",
    "#### Step 2\n",
    "To compute the separator marginal, we use the **shortcut** $P(b \\mid r)$:\n",
    "$$\n",
    "P(b) = \\sum_{r} P(b \\mid r) P(r).\n",
    "$$\n",
    "In general, a shortcut $P(S \\mid R)$ directly conditions this clique's separator $S$ on the root clique $R$, even if there are many other cliques in between. That is why it is called a *shortcut*.\n",
    "\n",
    "#### Step 3 (optional)\n",
    "If the shortcut has already been computed, then we are done! If not, we compute it recursively:\n",
    "$$\n",
    "P(S \\mid R) = \\sum_{(F_p \\cup S_p) \\setminus S} P(F_p \\mid S_p) P(S_p \\mid R).\n",
    "$$\n",
    "Above, $P(F_p \\mid S_p)$ is the parent clique, and by the running intersection property we know that the separator $S$ is a subset of the parent clique's variables.\n",
    "The recursion arises because we might not yet have the parent's shortcut $P(S_p \\mid R)$, which may itself have to be computed in turn. The recursion ends at cliques directly below the root, and **once we have obtained $P(S \\mid R)$ we cache it**.\n",
    "\n",
    "In our example, the computation is simply\n",
    "$$\n",
    "P(b \\mid r) = \\sum_{c} P(b, c \\mid r),\n",
    "$$\n",
    "because the parent's separator is already the root, so $P(S_p \\mid R)$ is omitted. This is also where the recursion ends.\n"
   ]
  },
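  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a second instance of the same recipe, here is a sketch of the three steps for the leaf $f$ on the other branch, whose clique $P(f \\mid e)$ has separator $e$ and parent clique $P(d, e \\mid r)$:\n",
    "\n",
    "$$\n",
    "P(e \\mid r) = \\sum_{d} P(d, e \\mid r), \\qquad P(e) = \\sum_{r} P(e \\mid r) P(r), \\qquad P(f) = \\sum_{e} P(f \\mid e) P(e).\n",
    "$$\n",
    "\n",
    "Here $P(e \\mid r)$ plays the role of the shortcut $P(S \\mid R)$, and again the parent's separator is the root, so no further recursion is needed."
   ]
  },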
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "Marginal P(a):\n",
      " Discrete Conditional\n",
      " P( 0 ):\n",
      " Choice(0) \n",
      " 0 Leaf 0.51\n",
      " 1 Leaf 0.49\n",
      "\n",
      "\n",
      "3\n"
     ]
    }
   ],
   "source": [
    "# Marginal of the leaf variable 'a':\n",
    "print(bayesTree.numCachedSeparatorMarginals())  # cache size before the query\n",
    "marg_a = bayesTree.marginalFactor(aKey[0])  # aKey is a (key, cardinality) pair\n",
    "print(\"Marginal P(a):\\n\", marg_a)\n",
    "print(bayesTree.numCachedSeparatorMarginals())  # cache size after the query\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n",
      "Marginal P(b):\n",
      " Discrete Conditional\n",
      " P( 2 ):\n",
      " Choice(2) \n",
      " 0 Leaf 0.48\n",
      " 1 Leaf 0.52\n",
      "\n",
      "\n",
      "3\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Marginal of the internal variable 'b':\n",
    "print(bayesTree.numCachedSeparatorMarginals())\n",
    "marg_b = bayesTree.marginalFactor(bKey[0])\n",
    "print(\"Marginal P(b):\\n\", marg_b)\n",
    "print(bayesTree.numCachedSeparatorMarginals())\n"
   ]
  },
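  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed values can be checked by hand. The following is a sketch that assumes the string passed to `DecisionTreeFactor` lists values with the last key ($r$) varying fastest, and that the `DiscreteConditional` constructor normalizes over the frontal variables for each value of $r$; both assumptions are consistent with the outputs above. Under them, $P(b, c \\mid r{=}0)$ has entries $0.3, 0.1, 0.2, 0.4$ (already summing to one) and $P(b, c \\mid r{=}1)$ has entries $0.7/3, 0.9/3, 0.8/3, 0.6/3$, so the shortcut is\n",
    "\n",
    "$$\n",
    "P(b{=}0 \\mid r{=}0) = 0.3 + 0.1 = 0.4, \\qquad P(b{=}0 \\mid r{=}1) = \\frac{0.7 + 0.9}{3} = \\frac{1.6}{3}.\n",
    "$$\n",
    "\n",
    "With $P(r{=}0) = 0.4$ and $P(r{=}1) = 0.6$, Step 2 gives\n",
    "\n",
    "$$\n",
    "P(b{=}0) = 0.4 \\cdot 0.4 + 0.6 \\cdot \\frac{1.6}{3} = 0.16 + 0.32 = 0.48,\n",
    "$$\n",
    "\n",
    "and, reading the table `1/3 3/1` as $P(a{=}0 \\mid b{=}0) = 0.25$ and $P(a{=}0 \\mid b{=}1) = 0.75$, Step 1 gives\n",
    "\n",
    "$$\n",
    "P(a{=}0) = 0.25 \\cdot 0.48 + 0.75 \\cdot 0.52 = 0.12 + 0.39 = 0.51.\n",
    "$$"
   ]
  },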
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n",
      "Joint P(a, f):\n",
      " DiscreteBayesNet\n",
      " \n",
      "size: 2\n",
      "conditional 0:  P( 0 | 1 ):\n",
      " Choice(1) \n",
      " 0 Choice(0) \n",
      " 0 0 Leaf 0.51758893\n",
      " 0 1 Leaf 0.48241107\n",
      " 1 Choice(0) \n",
      " 1 0 Leaf 0.50222672\n",
      " 1 1 Leaf 0.49777328\n",
      "\n",
      "conditional 1:  P( 1 ):\n",
      " Choice(1) \n",
      " 0 Leaf 0.506\n",
      " 1 Leaf 0.494\n",
      "\n",
      "\n",
      "3\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Joint of leaf variables 'a' and 'f': P(a, f)\n",
    "# This effectively needs to gather info from two different branches\n",
    "print(bayesTree.numCachedSeparatorMarginals())\n",
    "marg_af = bayesTree.jointBayesNet(aKey[0], fKey[0])\n",
    "print(\"Joint P(a, f):\\n\", marg_af)\n",
    "print(bayesTree.numCachedSeparatorMarginals())\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
